# Some of the functions used were taken from: https://eu.udacity.com/course/self-driving-car-engineer-nanodegree--nd013
# The implemented functions have been modified and improved in order to perform some necessary tasks for the purpose of the thesis

# Import libraries
import cv2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import os
import sys
import warnings
import timeit
import datetime
import argparse

def save(filename, image, file_extension):
    """
    ## Save an image on disk as `filename + file_extension`.
    ## cv2.imwrite picks the encoder from the extension, so `file_extension` must include
    ## the leading dot (e.g. '.png', '.jpg').
    ## Returns True when the image was written, False otherwise (a warning is printed).
    """
    path = filename + file_extension
    is_written = cv2.imwrite(path, image)
    if is_written:
        print(filename + " Image successfully saved!")
    else:
        # imwrite reports some failures (e.g. missing directory) only through its return value
        print("Warning: unable to save image to " + path)

    return is_written

def do_canny(frame, file_extension, verbose=True, ksize=5, low_threshold=50, high_threshold=150):
    """
    ## Detect edges in a frame with the Canny algorithm.
    ## The frame is converted from BGR (3 channels, OpenCV's native order) to grayscale (1 channel),
    ## blurred with a ksize x ksize Gaussian kernel and finally passed to cv2.Canny.
    ## Parameters:
    ##   frame          : BGR image (as returned by cv2.imread / VideoCapture.read) or an already-grayscale image
    ##   file_extension : extension, with leading dot, used to save intermediate images when verbose is True
    ##   verbose        : if True, show and save every intermediate image
    ##   ksize          : size of the (square, odd) Gaussian kernel
    ##   low_threshold  : lower hysteresis threshold for Canny
    ##   high_threshold : upper hysteresis threshold for Canny
    ## Returns the binary edge image produced by cv2.Canny.
    """
    # Edges only need the luminance channel - less computationally expensive than 3 channels.
    # OpenCV delivers frames in BGR order, so BGR2GRAY applies the correct luminance weights.
    if frame.ndim == 2:
        gray = frame
    else:
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

    # show the gray image
    if verbose:
        cv2.imshow("Gray Image", gray)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
        save('Gray', gray, file_extension)

    # Gaussian blur; sigma 0 lets OpenCV derive the deviation from the kernel size.
    # borderType must be passed by keyword: the 4th positional argument is `dst`.
    blur = cv2.GaussianBlur(gray, (ksize, ksize), 0, borderType=cv2.BORDER_DEFAULT)

    # show the smoothed image
    if verbose:
        cv2.imshow("Gaussian Smoothing", blur)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
        save('Gaussian Blur', blur, file_extension)

    # Canny edge detector with the given low/high hysteresis thresholds
    canny = cv2.Canny(blur, low_threshold, high_threshold)

    # show the canny image
    if verbose:
        cv2.imshow("canny", canny)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
        save('Canny', canny, file_extension)

    # Return the binary image
    return canny

def do_segment(frame, img, file_extension, verbose=True, vertices=None):
    """
    ## Keep only the region of interest (ROI) of a frame.
    ## A black mask with the same shape as `frame` is created and the ROI polygon is filled with
    ## white (255). A bitwise AND between mask and frame then keeps only the pixels inside the
    ## polygon, where the lane lines are expected.
    ## Parameters:
    ##   frame          : image to crop (the Canny edge image in this script)
    ##   img            : image on which the ROI outline is drawn in verbose mode (modified in place)
    ##   file_extension : extension, with leading dot, used to save intermediate images when verbose is True
    ##   verbose        : if True, show and save the ROI outline and the cropped image
    ##   vertices       : optional polygon vertices (x, y), reshapeable to (1, N, 2); by default the
    ##                    quadrilateral tuned for the 1280x720 frames main() resizes to
    ## Returns the masked image.
    """
    # frame.shape -> (rows=height, cols=width[, channels]); y grows downwards, so the bottom row is y = height
    height = frame.shape[0]
    if vertices is None:
        # Quadrilateral ROI; x-coordinates assume 1280-wide frames (see find_coordinates.py to re-tune)
        vertices = np.array([[(255, height),
                              (605, 395),
                              (735, 395),
                              (1037, height)]],
                            dtype=np.int32)
    else:
        # Normalize user-supplied vertices to a single int32 polygon as expected by fillPoly/polylines
        vertices = np.asarray(vertices, dtype=np.int32).reshape(1, -1, 2)

    # show and save the ROI applied on the original image frame
    if verbose:
        img_mod = cv2.polylines(img, [vertices], True, (255, 0, 0), 3)
        cv2.imshow('ROI', img_mod)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
        save('ROI', img_mod, file_extension)

    # All-black mask with the same dimensions as the frame
    mask = np.zeros_like(frame)
    # Fill the ROI with white (255) in every channel so multi-channel frames are masked too
    match_mask_color = (255,) * frame.shape[2] if frame.ndim == 3 else 255
    cv2.fillPoly(mask, vertices, match_mask_color)
    # Bitwise AND keeps only the polygon area of the frame
    segment = cv2.bitwise_and(frame, mask)

    # show the cropped image
    if verbose:
        cv2.imshow("Cropped", segment)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
        save('Cropped', segment, file_extension)

    return segment


def do_hough(frame):
    """
    ## Find line segments in a binary (edge) image with the probabilistic Hough transform.
    ## Accumulator: 2 px distance step, 1 degree angle step, at least 100 votes per line;
    ## segments shorter than 50 px are dropped and gaps up to 100 px are bridged.
    ## Returns the array of segments produced by cv2.HoughLinesP.
    """
    # Coarser steps mean a smaller accumulator but less precise lines
    distance_step = 2
    angle_step = np.pi / 180

    # At least this many edge points must vote for a line before it is reported
    min_votes = 100

    return cv2.HoughLinesP(
        frame,
        distance_step,
        angle_step,
        min_votes,
        lines=np.array([]),
        minLineLength=50,   # shortest segment (pixels) that is kept
        maxLineGap=100,     # largest gap (pixels) joined into one segment
    )

def calculate_lines(frame, lines):
    """
    ## Average the Hough segments into one left and one right lane line.
    ## Each segment is turned into (slope, y-intercept) of y = slope * x + intercept.
    ## Segments with negative slope are assigned to the left line, the others to the right line
    ## (image y grows downwards). The averaged slope/intercept of each side is then converted into
    ## end points by calculate_coordinates().
    ## Parameters:
    ##   frame : image the lines belong to (only its height is used, by calculate_coordinates)
    ##   lines : output of cv2.HoughLinesP, one [x1, y1, x2, y2] segment per entry, or None
    ## Returns np.array([left_line, right_line]) of end-point arrays, or None when either
    ## side cannot be determined.
    """
    # HoughLinesP returns None when it finds no segment at all
    if lines is None or len(lines) == 0:
        print("Caution! Cannot find the lines: missing or broken lane lines!\n")
        return None

    left = []
    right = []
    for line in lines:
        x1, y1, x2, y2 = map(float, line.reshape(4))
        if x1 == x2:
            # Vertical segment: slope is undefined, it cannot be written as y = m*x + b
            continue
        # Straight line through the two end points
        slope = (y2 - y1) / (x2 - x1)
        y_intercept = y1 - slope * x1
        if slope < 0:
            left.append((slope, y_intercept))
        else:
            right.append((slope, y_intercept))

    # An empty side means the algorithm failed to detect that lane line in this frame
    if not left:
        print("Caution: missing or broken left lane line!\n")
    if not right:
        print("Caution: missing or broken right lane line!\n")
    if not left and not right:
        print("Caution: missing or broken lane lines!\n")
    if not left or not right:
        return None

    # Average all segments of each side into a single slope and y-intercept
    left_avg = np.average(left, axis=0)
    right_avg = np.average(right, axis=0)

    # Degenerate averages (e.g. zero slope) cannot be converted into end points
    try:
        left_line = calculate_coordinates(frame, left_avg)
        right_line = calculate_coordinates(frame, right_avg)
    except (ArithmeticError, ValueError) as ex:
        print(ex, '\n')
        return None
    return np.array([left_line, right_line])

def calculate_coordinates(frame, parameters):
    """
    ## Convert a (slope, intercept) pair into the end points of a drawable segment.
    ## The segment starts on the bottom row of the frame (y1 = height) and ends at
    ## y2 = 2/3 of the height measured from the top, i.e. it spans the lower third of the frame.
    ## Parameters:
    ##   frame      : image whose height (frame.shape[0]) bounds the segment
    ##   parameters : (slope, intercept) of the line y = slope * x + intercept, in pixel coordinates
    ## Returns np.array([x1, y1, x2, y2]) with x values truncated toward zero by int().
    ## Raises ValueError if the slope is zero or a parameter is not finite (x cannot be solved for).
    """
    slope, intercept = parameters
    if slope == 0 or not np.isfinite(slope) or not np.isfinite(intercept):
        raise ValueError(
            f"Cannot compute line coordinates from slope={slope}, intercept={intercept}"
        )
    # Initial y-coordinate: bottom of the frame (y grows downwards)
    y1 = frame.shape[0]
    # Final y-coordinate: 2/3 of the height from the top
    y2 = int(y1 * 2 / 3)
    # x = (y - b) / m since y = m*x + b
    x1 = int((y1 - intercept) / slope)
    x2 = int((y2 - intercept) / slope)
    return np.array([x1, y1, x2, y2])

def draw_lines(frame, lines, file_extension, verbose=True):
    """
    ## Draw the lane lines on a black image with the same shape as `frame`.
    ## Each line is drawn in green (0, 255, 0 in BGR) with a thickness of 10 pixels.
    ## Parameters:
    ##   frame          : reference image (only its shape and dtype are used)
    ##   lines          : iterable of [x1, y1, x2, y2] end points (as returned by calculate_lines), or None
    ##   file_extension : extension, with leading dot, used to save the image when verbose is True
    ##   verbose        : if True, show and save the lines image
    ## Returns the image containing only the drawn lines.
    """
    # Black image with the same dimensions (and dtype) as the frame; frame itself is not modified
    lines_image = np.zeros_like(frame)
    if lines is not None:
        for x1, y1, x2, y2 in lines:
            # cv2.line needs plain Python ints; some OpenCV versions reject NumPy integer scalars
            cv2.line(lines_image, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), thickness=10)
    else:
        print("\nLines not found!")

    # show the lines
    if verbose:
        cv2.imshow("Hough", lines_image)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
        save('Hough', lines_image, file_extension)

    return lines_image

def usage():
    """
    ## Print a short command-line usage reminder (the input path is passed with the required -i flag).
    """
    print('Usage: ' + sys.argv[0] + ' -i <image_file | video_file>')
    print('Example: python SLLD.py -i=image.png (or -i=input_video.avi)')

# Every input is resized to this (width, height); the default ROI in do_segment() is tuned for it
_HD_SIZE = (1280, 720)
# Output frame rate used when the source video does not report a usable FPS
_DEFAULT_OUTPUT_FPS = 20.0


def _detect_lanes(frame, file_extension, announce=False):
    """
    ## Run the detection pipeline (Canny -> ROI -> Hough -> averaging -> drawing) on one BGR frame.
    ## Returns (output, lines): the frame blended with the detected lane lines, and the averaged
    ## lines from calculate_lines() (None when detection failed).
    ## When `announce` is True a progress message is printed before the main stages.
    """
    if announce:
        print("Applying the Canny method...")
    canny = do_canny(frame, file_extension, verbose=False)

    if announce:
        print("Applying the ROI to the image...")
    segment = do_segment(canny, frame.copy(), file_extension, verbose=False)

    if announce:
        print("Applying the Hough Transformation to the image...")
    hough = do_hough(segment)

    # Average the many Hough segments into one left and one right lane border
    lines = calculate_lines(frame, hough)
    lines_image = draw_lines(frame, lines, file_extension, verbose=False)

    # Overlay the lines: 0.9 * frame + 1 * lines, plus an arbitrary gamma of 1
    output = cv2.addWeighted(frame, 0.9, lines_image, 1, 1)
    return output, lines


def _process_image(pathfile):
    """
    ## Detect lane lines in one image file, save the result as 'Output<ext>' and display it.
    ## Exits with status 1 if the file is missing or unreadable.
    """
    if not os.path.isfile(pathfile):
        print("Input image file ", pathfile, " doesn't exist")
        usage()
        sys.exit(1)

    filename, file_extension = os.path.splitext(pathfile)
    print('filename: ', filename)
    print('file_extension: ', file_extension)

    # NOTE: OpenCV uses BGR as its default colour order, matplotlib uses RGB instead
    image = cv2.imread(pathfile)
    if image is None:
        print("Unable to read file. Exiting...")
        sys.exit(1)

    # cv2.imread loads a colour image by default: shape is (height, width, channels)
    height, width, channels = image.shape[:3]
    print('Image Height       : ', height)
    print('Image Width        : ', width)
    print('Number of Channels : ', channels)

    # Start timing the processing
    e1 = cv2.getTickCount()

    # check if optimization is enabled (default value = True)
    print("Optimized: ", cv2.useOptimized())

    print("Resizing the input image...")
    image = cv2.resize(image, dsize=_HD_SIZE)
    resized_h, resized_w = image.shape[:2]
    print('Shape of the resized image-> Width x Height:', resized_w, 'x', resized_h)

    # To re-tune the ROI vertices of do_segment(), use find_coordinates.py to capture mouse clicks on the image
    output, _ = _detect_lanes(image, file_extension, announce=True)

    save('Output', output, file_extension)
    print('Output image saved in: ', "Output" + file_extension)
    # Processing time -> to obtain a true value, keep intermediate displays disabled
    e2 = cv2.getTickCount()
    elapsed = (e2 - e1) / cv2.getTickFrequency()
    print('Total Processing Time: ', elapsed, 'seconds')

    print("Output image...")
    cv2.imshow("output", output)
    cv2.waitKey(0)
    cv2.destroyAllWindows()

    print("Finish. Exiting...")


def _process_video(pathfile):
    """
    ## Detect lane lines frame by frame in a video, show the result and write it to 'output.avi'.
    ## Press 'q' to stop early. Exits with status 1 if the file is missing or cannot be opened.
    """
    if not os.path.isfile(pathfile):
        print("Input video file ", pathfile, " doesn't exist")
        usage()
        sys.exit(1)

    cap = cv2.VideoCapture(pathfile)
    if not cap.isOpened():
        print("Unable to open video file. Exiting...")
        cap.release()
        sys.exit(1)

    # Take some information from the video
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    print("Width: ", width, " Height: ", height)
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    print("Number of total frames: ", frame_count)
    fps = cap.get(cv2.CAP_PROP_FPS)
    print("FPS: ", fps)
    # Some containers/backends report 0 FPS; avoid dividing by zero
    if fps > 0:
        print("Video Duration: ", frame_count / fps, 'Seconds')
    else:
        print("Video Duration: unknown (FPS not reported)")

    # Write at the source frame rate so the output plays at the original speed
    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    out_fps = fps if fps > 0 else _DEFAULT_OUTPUT_FPS
    out = cv2.VideoWriter('output.avi', fourcc, out_fps, _HD_SIZE)

    frames = 0  # processed frames
    bad = 0     # frames where the lane lines were not detected correctly
    start = timeit.default_timer()
    try:
        while cap.isOpened():
            # ret is False when no frame could be read (end of the video)
            ret, frame = cap.read()
            if not ret:
                print("\nCan't read frame. Reached the final frame. Exiting ...\n")
                break
            frames += 1
            print("\nFrame: ", frames)

            frame = cv2.resize(frame, dsize=_HD_SIZE)
            output, lines = _detect_lanes(frame, '.jpg')
            if lines is None:
                print("Frame", frames, "not detected correctly!")
                bad += 1

            out.write(output)
            cv2.imshow("output", output)
            # Wait 10 ms between frames; 'q' stops the processing
            if cv2.waitKey(10) & 0xFF == ord('q'):
                break
        elapsed = timeit.default_timer() - start
    finally:
        # Always free the capture/writer and close windows, even if processing fails
        cap.release()
        out.release()
        cv2.destroyAllWindows()

    print('\nTotal Processing Time: ', str(datetime.timedelta(seconds=elapsed)), '[h:m:s]')
    print('Total number of bad detections: ', bad)
    print('Total number of good detections: ', frames - bad)
    print('Total number of processed frames: ', frames)
    print('File output saved in: "output.avi"')


def main(args):
    """
    ## Entry point: dispatch on the extension of args['input'].
    ## Images (.png, .jpg, .jpeg) and videos (.avi, .mp4) are supported; anything else,
    ## or a missing input, prints the usage and exits with status 1.
    """
    pathfile = args.get('input')
    if not pathfile:
        print("No input file given. Exiting...")
        usage()
        sys.exit(1)

    print(f"The passed parameter is: {pathfile}")

    lowered = pathfile.lower()
    if lowered.endswith(('.png', '.jpg', '.jpeg')):
        _process_image(pathfile)
    elif lowered.endswith(('.avi', '.mp4')):  # most common video file extensions
        _process_video(pathfile)
    else:
        print("Unable to read file, format file not supported yet. Exiting...")
        usage()
        sys.exit(1)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Simple Lane Line Detection')
    # argparse lists flags under "optional arguments" by default; a dedicated group
    # makes the mandatory -i/--input stand out in --help
    required_args = parser.add_argument_group('required named arguments')
    required_args.add_argument("-i", "--input", required=True,
                               help="path to input image or video")
    main(vars(parser.parse_args()))